""" Robustness: what if the well's outside option is not zero?

- Do not include equilibrium effects of searching wells
- 1. Use the 'old' probabilities of matching for wells (not directly observed)
- 2. Plug these into the surplus
- 3. As a byproduct, save the cutoffs (for the acceptance-set visualization)

"""

#%% 1. Get the equilibrium probability of wells rematching ------------------------------
import pandas as pd
from scipy.stats import norm as norm
import numpy as np
import sys
sys.path.append('./')

from src.models_new import simulation, surplus
from src.run_scripts import utils
import statsmodels.formula.api as smf


# %% SET SOME CONSTANT ------------------------------------------------------------------
# Script-wide settings. `mri_max` is the upper truncation point used when
# normalising the normal CDF (params['denom']) further down.
mri_max = 2.15
time = 'month'
beta = 0.99
use_empirical_state = False
polynomial_order = 2
welfare_graph = True

# For value functions: grid sizes and the discount factor per time unit
mri_grid, g_grid, n_grid = 5, 10, 5
beta_by_time = dict([('month', 0.99), ('fortnight', 0.995)])

# Simulation lengths and the seeds 0..199
sim_length = 1200
max_sim_length = 120
seeds = [*range(200)]

# Worker settings (presumably for a parallel cluster -- not used in this section)
options = dict(threads_per_worker=1, n_workers=8)

# %% READ IN COMPONENTS -----------------------------------------------------------------
# Location of every input, keyed by the keyword argument of
# `utils.read_in_data` that receives it.
input_paths = {
    'path_moments_data': './models/smm_input/moments_empirical.csv',
    'path_n_rigs': "./models/first_stage/n_rigs",
    'path_surplus_components': "./models/surplus/surplus_components",
    'path_surplus_grid': './models/surplus/surplus_grid_2_low_month.npy',
    'path_df_state': './data_py/processed/states',
    'path_delta': './models/smm_input/delta.csv',
    'path_rho': './models/smm_input/rho.csv',
    'path_df_contracts': './data_py/processed/contracts_final.csv',
    'path_price_match_values': './models/price_match/price_match_values',
    'path_coefs_data': './models/smm_input/coefs_data',
    'path_prob_match_predict_contracts': './models/robustness/prob_match_predict_contracts',
    'path_prob_match_predict': './models/robustness/prob_match_predict',
}

# The loader hands back five objects; only `data` is used in the rest of this section.
data, delta, rho, c, weights = utils.read_in_data(
    **input_paths,
    time_period='month',
    use_myopic=True
)
#%% PUT PARAMS TOGETHER FOR SIMULATION---------------------------------------------------
# Estimated parameters: the CSV column labelled '0', keyed by the CSV's first column.
params = pd.read_csv(
    './models/smm/params_smm_with_diff_new.csv',
    index_col=[0],
).to_dict()['0']

# Length-3 coefficient vectors, ordered low / mid / high.
params['a_0'] = np.ones(3)
params['a_1'] = np.array([params[f'a_1_{level}'] for level in ('low', 'mid', 'high')])

# Mass that N(mu_0, sigma_0) puts on the interval [0, mri_max].
cdf_at_mri_max = norm.cdf(mri_max, loc=params['mu_0'], scale=params['sigma_0'])
cdf_at_zero = norm.cdf(0, loc=params['mu_0'], scale=params['sigma_0'])
params['denom'] = cdf_at_mri_max - cdf_at_zero

# p_4 is the remainder after p_3 and p_2 (subtraction order kept as-is).
params['p_4'] = (1.0 - params['p_3']) - params['p_2']

#%% DO SIMULATION -----------------------------------------------------------------------
# `surplus.build_fast_surplus` returns a pair; the names indicate surplus values
# and the wells' outside option, each indexed by (tau, spec).
# Only the surplus values are passed on to the simulation below.
(
    surplus_values_by_tau_spec,
    well_outside_option_by_tau_spec
) = surplus.build_fast_surplus(
    data['match_values_by_tau_spec'],
    params,
    data['surplus_grid'],
    data['non_myopic_dict']
)

# Benchmark simulation. `mri_max` is taken from the module-level constant (the
# same one used for params['denom']) instead of repeating the literal 2.15, so
# the two truncation points cannot drift apart.
# NOTE(review): value_zero=False presumably switches on the non-zero outside
# option described in the module docstring -- confirm against do_simulation.
results_benchmark = simulation.do_simulation(
    data['state_data'],
    params,
    surplus_grid=data['surplus_grid'],
    surplus_values_by_tau_spec=surplus_values_by_tau_spec,
    n_rigs=data['n_rigs'],
    mri_max=mri_max,
    verbose=True,
    value_zero=False
)

#%% GET THE PROB. OF EACH WELLS MATCHING IN THE DATA
# Check the # matches is the same no matter how it is computed...
df_moments = pd.DataFrame(results_benchmark['moments_by_state'])
df_shares = pd.DataFrame(results_benchmark['shares_by_state'], columns=['low', 'mid', 'high'])
df_moments['g'] = data['state_data']['g']

# Get number of NEW matches: for every (spec, period key), add up over
# tau = 2, 3, 4 the rig count of that spec times the summed new-match probabilities.
prob_new_match_by_ym = results_benchmark['prob_new_match_by_ym']
n_new_matches_by_spec = {}
for spec in ('low', 'mid', 'high'):
    rigs_of_spec = data['n_rigs'][spec]
    for ym in prob_new_match_by_ym:
        # Explicit running total keeps the original left-to-right float addition.
        running_total = 0.0
        for tau in (2, 3, 4):
            running_total += rigs_of_spec * prob_new_match_by_ym[ym][0][(spec, tau)].sum()
        n_new_matches_by_spec[(spec, ym)] = running_total

# Series indexed by (spec, period key)
df_new_matches = pd.Series(n_new_matches_by_spec)

#%%
state_cols = ['g', 'n_l', 'n_m', 'n_h']
df_state_simple = pd.DataFrame(results_benchmark['state_simple_by_ym']).T
df_state_simple.columns = state_cols

#%%
# Factor shared by every spec's target: d_0 + d_1 * g. `g` is not touched
# inside the loop, so it is evaluated once up front.
target_scale = params['d_0'] + params['d_1'] * df_moments['g']

for spec in ('low', 'mid', 'high'):
    matches_col = f'n_new_matches_{spec}'
    share_col = f'share_{spec}'
    target_col = f'n_target_{spec}'
    prob_col = f'p_match_well_{spec}'

    # Get number of NEW matches
    df_moments[matches_col] = df_new_matches.xs(spec).to_numpy()

    # Get prob of well matching IMPLIED (note: does not take into account rejected matches)
    df_moments[share_col] = df_shares[spec]
    df_moments[target_col] = df_moments[share_col] * target_scale
    df_moments[prob_col] = df_moments[matches_col] / df_moments[target_col]

# Attach the simple state variables positionally (this overwrites `g` as well).
df_moments[state_cols] = df_state_simple[state_cols].to_numpy()

#%% 1. Fit a logit (per spec) for the prob. of matching at the well level
# Regressors: second-order polynomial in the state (g, n_l, n_m, n_h).
# Squares must be written as I(x ** 2): in a patsy formula `x:x` collapses to
# `x`, so self-interactions silently add nothing to the design matrix.
coefs_p_match_well = dict()
reg_by_spec = dict()
for spec in ['low', 'mid', 'high']:
    outcome = f'p_match_well_{spec}'

    # The implied probability is a ratio (new matches / targeted wells) and can
    # reach or exceed one; pull such values just below one for every spec, not
    # only 'high', since all three are fitted with the same logit.
    df_moments.loc[df_moments[outcome] >= 1.0, outcome] = 0.9999999

    formula = (
        f"{outcome} "
        "~ g + n_l + n_m + n_h"
        " + I(g ** 2) + g:n_l + g:n_m + g:n_h"
        " + I(n_l ** 2) + n_l:n_m + n_l:n_h"
        " + I(n_m ** 2) + n_m:n_h + I(n_h ** 2)"
    )
    reg_by_spec[spec] = smf.logit(
        formula=formula,
        data=df_moments
    ).fit()
    coefs_p_match_well[spec] = reg_by_spec[spec].params

#%% 2. Read in surplus nodes and predict
regressor_cols = ['g', 'n_l', 'n_m', 'n_h']

# The node file does not depend on spec, so read it once instead of once per
# loop iteration.
df_surplus_nodes = pd.read_csv(
    './models/surplus/nodes_month.txt',
    delimiter=' ',
    names=['mri', 'g', 'n_l', 'n_m', 'n_h']
)
node_regressors = df_surplus_nodes[regressor_cols]

prob_match_predict_by_spec = dict()
for spec in ['low', 'mid', 'high']:
    # Predicted well-match probability at every surplus node
    prob_match_predict = reg_by_spec[spec].predict(node_regressors)

    # Save
    prob_match_predict.to_csv(f'./models/robustness/prob_match_predict_{spec}.csv')

    # Convert to array
    prob_match_predict_by_spec[spec] = np.array(prob_match_predict)

    # ALSO do predictions from contracts (only the rows belonging to this spec)
    is_spec = data['df_contracts']['spec'] == spec
    prob_match_predict_contracts = reg_by_spec[spec].predict(
        data['df_contracts'].loc[is_spec, regressor_cols]
    )

    # Save
    prob_match_predict_contracts.to_csv(
        f'./models/robustness/prob_match_predict_contracts_{spec}.csv')
